import numpy as np
import copy
import random
from scipy.signal import correlate2d, fftconvolve
import uncertainties
from errorprop import ErrorVal


def findNeighbors2(x, y, m, mX, mY):
    """Return the 4-connected neighbours a hole at (x, y) may move to.

    An in-bounds neighbour with non-zero ``m`` is always allowed.  An
    in-bounds neighbour with ``m == 0`` is allowed only when the site (x, y)
    itself has ``m == 0``.  Neighbours are listed in the order
    x+1, x-1, y+1, y-1, each as an [x, y] list.
    """
    allowed = []
    for nx, ny, in_bounds in ((x + 1, y, x + 1 < mX),
                              (x - 1, y, x - 1 >= 0),
                              (x, y + 1, y + 1 < mY),
                              (x, y - 1, y - 1 >= 0)):
        if not in_bounds:
            continue
        # Occupied neighbours are always reachable; empty ones only from an
        # empty site.  m[x, y] is read only when the neighbour is empty.
        if m[nx, ny] != 0 or m[x, y] == 0:
            allowed.append([nx, ny])
    return allowed


def findSite(m, holeEndingPos):
    """Pick a random lattice site inside the ROI that is not a hole end.

    Sites are drawn uniformly with ``np.random.randint`` (row index, then
    column index) and rejected until one has ``m[x, y] != 0`` (0 marks a site
    outside the ROI) and is not listed in `holeEndingPos`.

    Args:
        m: 2D lattice array; 0 = outside ROI, any other value is usable.
        holeEndingPos: sequence of [x, y] positions that must be avoided.

    Returns:
        Tuple (x, y) of the chosen site.

    Raises:
        ValueError: if no site satisfies both conditions (rejection sampling
            could then never terminate).
    """
    mX, mY = m.shape[0], m.shape[1]
    taken = set((int(p[0]), int(p[1])) for p in holeEndingPos)

    # Fail fast instead of sampling forever when every ROI site is excluded.
    # This check draws no random numbers, so the random sequence consumed
    # below is unchanged whenever a valid site exists.
    free = np.asarray(m) != 0
    for x, y in taken:
        if 0 <= x < mX and 0 <= y < mY:
            free[x, y] = False
    if not free.any():
        raise ValueError("findSite: no free site inside the ROI")

    # Explicit loop (not recursion) so that rare valid sites cannot exhaust
    # Python's recursion limit.
    while True:
        centerX = np.random.randint(mX)
        centerY = np.random.randint(mY)
        if (centerX, centerY) not in taken and m[centerX, centerY] != 0:
            return centerX, centerY


def setStringLength(distrCum):
    """Draw a string length from a cumulative distribution.

    ``distrCum[i]`` is read as the cumulative probability P(length <= i); it is
    expected to be non-decreasing and to end at 1.  A uniform
    ``r = random.random()`` in [0, 1) is drawn, and the first index ``i`` with
    ``r < distrCum[i]`` is returned.

    Returns:
        int length in ``range(len(distrCum))``; ``len(distrCum)`` only when
        ``r`` is not below any entry (e.g. the table does not reach 1).
    """
    r = random.random()
    for length, upper in enumerate(distrCum):
        # Half-open bins [previous, upper): r == 0.0 or r exactly on a
        # boundary still lands in a bin instead of falling off the table.
        if r < upper:
            return length
    return len(distrCum)


def new_string(m, ref, minL, maxL, holeEndingPos, distrCum):
    """Create one hole string on lattice `m` by a random walk of a hole.

    A start site inside the ROI (m != 0) that is not in `holeEndingPos` is
    chosen with `findSite` and set to -1.  A walk length is drawn with
    `setStringLength(distrCum)`.  The hole then takes that many random steps
    to neighbouring sites, swapping values with each site it enters.  It never
    steps straight back onto the site it just left and never enters a
    position in `holeEndingPos`.

    If the hole gets stuck (no allowed neighbour), the attempt is abandoned
    and `new_string` is called again on `mOld`, the snapshot of `m` taken on
    entry; that retry's result is returned.
    NOTE(review): retries recurse without bound.  When a retry happens, the
    `m` passed in has already been modified, so callers should use the
    returned array.

    Args:
        m: 2D lattice array (treated as square: mY = mX); modified in place.
        ref, minL, maxL: not used in this body (only forwarded to retries).
        holeEndingPos: list of [x, y] end positions of existing strings; the
            final hole position is appended to it in place.
        distrCum: cumulative length distribution passed to setStringLength.

    Returns:
        (m, holeEndingPos, movesList, string):
        - m: the updated lattice.
        - holeEndingPos: the end-position list.
        - movesList: one label per step (0: first index +1, 1: first index -1,
          2: second index +1, 3: second index -1).
        - string: the visited positions, starting with the start site.
    """
    mX = m.shape[0]
    mY = mX
    mOld = np.copy(m)  # snapshot used to retry if this attempt dead-ends
    movesList = []
    string = []

    centerX, centerY = findSite(m, holeEndingPos)
    # findSite already rejects hole-end positions; this is a defensive re-check.
    while [centerX, centerY] in holeEndingPos:  # hole should not start on other hole
        centerX, centerY = findSite(m, holeEndingPos)
    m[centerX, centerY] = -1

    # Draw the number of hole moves from distrCum.  (Original note: the
    # distribution is meant to make the longest 2 lengths less probable; that
    # depends on the caller's distrCum.)
    length = setStringLength(distrCum)

    oldHolePosition = [centerX, centerY]
    string.append([centerX, centerY])
    oldOldPos = oldHolePosition  # site the hole just left; starts as the start site itself

    while length > 0:
        # The current hole site holds -1 (non-zero), so findNeighbors2 returns
        # only in-bounds neighbours with m != 0 here.
        neighbors = findNeighbors2(oldHolePosition[0], oldHolePosition[1], m, mX,
                                   mY)  # findN2 allows to go from empty to empty site, but not from occupied to empty!
        neighbors = [n for n in neighbors if (n[0] != oldOldPos[0] or n[1] != oldOldPos[
            1]) and n not in holeEndingPos]  # exclude positions of already existing perturbations
        if len(neighbors) == 0:
            # Dead end: restart from the snapshot (recursively) and return
            # whatever the retry produces.
            m, holeEndingPos, movesList, string = new_string(mOld, ref, minL, maxL, holeEndingPos, distrCum)
            break
        else:
            nextSite = np.random.randint(0, len(neighbors))
            holeX = neighbors[nextSite][0]
            holeY = neighbors[nextSite][1]
            # Move the neighbour's value into the old hole site; the hole
            # (-1) takes the neighbour's place.
            m[oldHolePosition[0], oldHolePosition[1]] = np.copy(m[holeX, holeY])
            m[holeX, holeY] = -1
            ######### get executed move-label ###########
            if oldHolePosition[0] < holeX:  # hole moved to the right (first index +1)
                move = 0
            elif oldHolePosition[0] > holeX:  # first index -1
                move = 1
            elif oldHolePosition[1] < holeY:  # hole moved down (second index +1)
                move = 2
            else:  # second index -1
                move = 3
            movesList.append(move)
            string.append([holeX, holeY])
            oldOldPos = copy.copy(oldHolePosition)
            oldHolePosition = [holeX, holeY]
            length = length - 1
    # True only when the walk completed in this call; after a dead-end retry,
    # length is still > 0 and the retry has already recorded its own end.
    if length == 0:
        holeEndingPos.append(oldHolePosition)

    return m, holeEndingPos, movesList, string


def findNeighbors(x, y, m, mX, mY):
    """Return the 4-connected neighbours of (x, y) that lie inside the ROI.

    A neighbour is kept when the coordinate that moved stays within the
    lattice (``x+1 < mX``, ``x-1 >= 0``, ``y+1 < mY``, ``y-1 >= 0``) and
    ``m`` is non-zero there (0 marks a site outside the ROI).  Neighbours are
    listed in the order x+1, x-1, y+1, y-1, each as an [x, y] list.
    """
    candidates = (
        (x + 1, y, x + 1 < mX),
        (x - 1, y, x - 1 >= 0),
        (x, y + 1, y + 1 < mY),
        (x, y - 1, y - 1 >= 0),
    )
    return [[cx, cy] for cx, cy, in_bounds in candidates
            if in_bounds and m[cx, cy] != 0]


def continueString(strings, finishedStrings, mDiff, m):
    """Advance the breadth-first string search by one step.

    Takes the first partial string from `strings` and looks at the neighbours
    (from `findNeighbors`, lattice side ``mDiff.shape[0]`` on both axes) of
    its last site.  Every neighbour whose mDiff value is 2 and that is not
    already on the string yields a new string one site longer, appended to the
    end of `strings`.  If there is no such neighbour, the string is complete
    and is appended to `finishedStrings`.  The processed string is then
    removed from the front of `strings`.

    Both lists are modified in place and also returned as
    ``(strings, finishedStrings)``.
    """
    side = mDiff.shape[0]
    current = strings[0]
    tip = current[-1]
    extensions = [
        site
        for site in findNeighbors(tip[0], tip[1], m, side, side)
        if mDiff[site[0], site[1]] == 2 and site not in current
    ]
    if extensions:
        strings.extend(current + [ext] for ext in extensions)
    else:
        finishedStrings.append(current)
    del strings[0]
    return strings, finishedStrings


def possibleString(redArea, m, mDiff):
    """Find and undo the longest hole string inside a connected red area.

    Candidate start points are the sites of `redArea` where ``m == -1``.  From
    each one, `continueString` grows all self-avoiding paths through
    neighbours whose mDiff value is 2.  The search is breadth-first; once more
    than 1000 partial paths are pending, they are accepted as they are.  The
    longest path found is then undone: the sign of ``m`` is flipped on every
    site of it and ``mDiff`` is set to 0 there.

    No path is found exactly when no site of `redArea` has ``m == -1``,
    because each start point yields at least the one-site path.  In that case:
      * a single-site area is a "dark flip": ``m`` there is mapped by
        ``2*(|1 - m| - 0.5)`` (1 -> -1, 0 -> 1), ``mDiff`` is cleared, and
        length 0 / extent (0, 0) is returned;
      * otherwise every site of the area has ``m`` negated and ``mDiff``
        cleared, and the sentinel length 100 / extent (100, 0) is returned.

    Args:
        redArea: iterable of [x, y] sites (converted to a list).
        m: lattice array; modified in place.
        mDiff: array indexed like `m`; modified in place.

    Returns:
        (longestLength, aspectRatio, darkFlip, m, mDiff):
        - longestLength: number of moves (sites - 1) of the chosen path.
        - aspectRatio: the path's bounding-box extent
          (max - min along axis 0, max - min along axis 1); it is not an
          actual ratio.
        - darkFlip: 1 in the dark-flip case, else 0.
    """
    redArea = list(redArea)
    darkFlip = 0
    # candidate start points: sites of the red area that currently hold a hole (m == -1)
    possibleHoles_tmp = [pos for pos in redArea if m[pos[0], pos[1]] == -1]
    # get list with every site only once (order preserved)
    possibleHoles = []
    for hole in possibleHoles_tmp:
        if hole not in possibleHoles:
            possibleHoles.append(hole)

    finishedStrings = []
    for hole in possibleHoles:
        strings = [[hole]]
        while len(strings) > 0:
            strings, finishedStrings = continueString(strings, finishedStrings, mDiff, m)
            # cap the search: accept the pending partial strings as they are
            if len(strings) > 1000:
                finishedStrings = finishedStrings + strings
                strings = []
    lengths = []
    aspectRatios = []
    for string in finishedStrings:
        lengths.append(len(string) - 1)  # number of moves = sites - 1
        coords = np.transpose(string)
        # bounding-box extent along each axis
        ar = (max(coords[0]) - min(coords[0]),
              max(coords[1]) - min(coords[1]))
        aspectRatios.append(ar)
    if len(lengths) == 0:
        # Only reachable when possibleHoles is empty, so the first test below
        # is always true here and m at the site is never -1.
        if len(possibleHoles) == 0 and len(redArea) == 1:
            darkFlip = 1
            site = redArea[0]
            m[site[0], site[1]] = 2 * (abs(1 - m[site[0], site[1]]) - 0.5)
            mDiff[site[0], site[1]] = 0
            longestLength = 0
            aspectRatio = (0, 0)
        else:
            # sentinel values; the whole area is sign-flipped and cleared
            longestLength = 100
            aspectRatio = (100, 0)
            for site in redArea:
                m[site[0], site[1]] = -m[site[0], site[1]]
                mDiff[site[0], site[1]] = 0
    else:
        # Longest first.  Both sorts use the same key and sorted() is stable,
        # so indices == indicesar and aspectRatios stay aligned with strings.
        indices = sorted(range(len(lengths)), key=lambda k: -lengths[k])
        indicesar = sorted(range(len(aspectRatios)), key=lambda k: -lengths[k])
        strings = [finishedStrings[i] for i in indices]
        lengths = [lengths[i] for i in indices]
        aspectRatios = [aspectRatios[i] for i in indicesar]
        # undo the chosen string: flip the sign of m and clear mDiff along it
        for site in strings[0]:
            m[site[0], site[1]] = -m[site[0], site[1]]
            mDiff[site[0], site[1]] = 0
        longestLength = lengths[0]
        aspectRatio = aspectRatios[0]

    return longestLength, aspectRatio, darkFlip, m, mDiff


def findRedArea(red, redArea, redSites, m, mX):
    """Collect the connected region of red sites that contains `red`.

    Starting from `red`, walk depth-first over the 4-connected neighbours
    returned by `findNeighbors` (in-bounds sites with m != 0) that are members
    of `redSites`.

    Args:
        red: starting site [x, y]; its neighbours are always explored.
        redArea: sites already known to belong to the area.  They are treated
            as visited (their neighbours are not explored) and keep their
            first-occurrence order at the front of the result.  This list is
            not modified.
        redSites: candidate sites.  Membership is tested with ``==``, so
            entries should be [x, y] lists like those from `findNeighbors`.
        m: lattice array passed to `findNeighbors`.
        mX: lattice side (the lattice is treated as square).

    Returns:
        List of unique sites, in this order: redArea entries, then `red`, then
        newly found sites in depth-first preorder.

    Uses an explicit stack instead of recursion, so large or snake-shaped
    areas cannot exceed Python's recursion limit.  It also avoids
    re-concatenating and re-deduplicating the whole list on every step.  The
    visiting order is the same as a recursive depth-first walk.
    """
    area = []
    for site in list(redArea) + [red]:
        if site not in area:
            area.append(site)

    def _unvisited_red_neighbors(site):
        # Filtered against the sites visited at the moment `site` is entered.
        return [n for n in findNeighbors(site[0], site[1], m, mX, mX)
                if n in redSites and n not in area]

    stack = [iter(_unvisited_red_neighbors(red))]
    while stack:
        nxt = next(stack[-1], None)
        if nxt is None:
            stack.pop()
        elif nxt not in area:  # may already have been reached via an earlier sibling
            area.append(nxt)
            stack.append(iter(_unvisited_red_neighbors(nxt)))
    return area


def getcorr(ams, type, d=None):
    """Compute an azimuthally averaged two-point correlator from lattice images.

    Args:
        ams: sequence of image groups (the loop comment below names them
            noblow, blow1, blow2).  Each group is a sequence of 2D arrays with
            1 = particle, -1 = hole, 0 = outside the ROI.  Type "ss" reads
            groups 0, 1 and 2; every other type reads only group 0.
        type: which correlator to build: "ss", "pp", "hh", "ph", "g2" or "g2n".
            NOTE(review): any other value leaves `res` unbound and fails with
            NameError further down.
        d: doping, required (asserted) for "g2n".  A negative value is
            replaced by its absolute value with a printed warning.

    Returns:
        (xs, ys, es): lists of
        - radial distances (in lattice units);
        - the inverse-variance weighted mean of the correlator over the
          displacements at each distance;
        - the standard error of that mean.

    NOTE(review): Python 2 only as written.  It uses a print statement and a
    tuple-parameter lambda, and relies on `/` doing integer division for
    `hsize`.
    """
    # Displacements with fewer than min_samps site pairs are masked out.
    min_samps = 100
    # A radial bin is kept if it holds at least min_samps_azi displacements or
    # lies within min_radius_azi of the origin (with 1, every bin is kept).
    min_samps_azi = 1
    min_radius_azi = 6
    corels = []

    # Half the image side L (integer division in Python 2).  2*hsize-1 is used
    # below as the zero-displacement index of the (2L-1)-wide correlation map.
    # That holds only for even L.
    # NOTE(review): for odd L the centre is off by one -- confirm inputs are even.
    hsize = len(ams[0][0]) / 2

    # Iterate through noblow, blow1, blow2
    for i, scan in enumerate(ams):
        cor = dict()
        j_sub = []

        # add up pp and hh correlations from each image within blowout group
        for j, am in enumerate(scan):
            # am has 1=particle, -1=hole, 0=outside ROI
            arr = np.copy(am)
            arr[arr == -1] = 0

            # This wants 1=particle, 0=hole OR outside ROI
            # Convolving with the flipped array gives the autocorrelation, i.e.
            # the number of particle-particle pairs per displacement.  rint
            # removes FFT round-off before the cast to int.
            pp = np.rint(fftconvolve(arr, arr[::-1, ::-1])).astype(dtype='int')
            arr = -np.copy(am)
            arr[arr == -1] = 0
            # This wants 1=hole, 0=particle OR outside ROI
            hh = np.rint(fftconvolve(arr, arr[::-1, ::-1])).astype(dtype='int')

            cor['pp'] = cor['pp'] + pp if 'pp' in cor else pp
            cor['hh'] = cor['hh'] + hh if 'hh' in cor else hh
            j_sub += [j]

        # NOTE(review): `am` is the last image of the group here, so its ROI
        # mask normalises the whole group (assumes all images share one ROI).
        arr = np.copy(am)
        arr[arr != 0] = 1

        total = correlate2d(arr, arr)  # calculate max possible value for correlators
        # In-ROI site pairs per displacement, summed over the group's images.
        cor['to'] = total * float(len(j_sub))

        # Poorly sampled displacements get a huge denominator; their values
        # are overwritten just below anyway.
        masked_to = np.copy(cor['to'])
        masked_to[masked_to < min_samps] = 1000000.

        # Pair fractions with binomial standard error sqrt(p*(1-p)/N).
        cor['pp'] = ErrorVal(cor['pp'] / masked_to,
                             np.sqrt(cor['pp'] / masked_to * (1 - cor['pp'] / masked_to) / masked_to))
        cor['hh'] = ErrorVal(cor['hh'] / masked_to,
                             np.sqrt(cor['hh'] / masked_to * (1 - cor['hh'] / masked_to) / masked_to))

        # Masked displacements: pp = 0, hh = 1, both with zero error.
        cor['pp'].value[cor['to'] < min_samps] = 0
        cor['pp'].error[cor['to'] < min_samps] = 0
        cor['hh'].value[cor['to'] < min_samps] = 1
        cor['hh'].error[cor['to'] < min_samps] = 0

        cor['ph'] = 1 - cor['pp'] - cor['hh']

        cor['dens'] = 0.5 + 0.5 * cor['pp'] - 0.5 * cor['hh']
        cor['pol'] = 0.5 * cor['pp'] - 0.5 * cor['hh']

        corels.append(cor)

    # Pair-sample counts of group 0.  They are also used for groups 1 and 2 in
    # "ss" (assumes equal counts -- TODO confirm).
    n = corels[0]['to']

    def sq(x, n):
        # (n*x**2 - x)/(n - 1): x**2 with a finite-sample correction (the
        # unbiased estimate of p**2 from the mean x of n Bernoulli samples).
        # NOTE(review): divides by zero when n == 1; calc() only guards n == 0.
        return (x ** 2) * n / (n - 1) - x / (n - 1)

    # def sqp(x, n):
    #     return (x**2)*n/(n-1)-1/(n-1)

    if type == "ss":
        # Per group, connected correlator C = b - sq(dens, n), with b and dens
        # given covariance b*(1 - dens)/n.  Result: 2*C1 + 2*C2 - C0.
        b0 = corels[0]['pp'].pack()
        b1 = corels[1]['pp'].pack()
        b2 = corels[2]['pp'].pack()
        d0 = corels[0]['dens'].pack()
        d1 = corels[1]['dens'].pack()
        d2 = corels[2]['dens'].pack()

        def calc(b0, b1, b2, d0, d1, d2, n):
            if n == 0: return uncertainties.ufloat(0, 0)
            cov0 = b0.value * (1 - d0.value) / n
            (b0, d0) = uncertainties.correlated_values([b0.value, d0.value],
                                                       [[b0.error ** 2, cov0], [cov0, d0.error ** 2]])
            cov1 = b1.value * (1 - d1.value) / n
            (b1, d1) = uncertainties.correlated_values([b1.value, d1.value],
                                                       [[b1.error ** 2, cov1], [cov1, d1.error ** 2]])
            cov2 = b2.value * (1 - d2.value) / n
            (b2, d2) = uncertainties.correlated_values([b2.value, d2.value],
                                                       [[b2.error ** 2, cov2], [cov2, d2.error ** 2]])
            return 2 * (b1 - sq(d1, n)) + 2 * (b2 - sq(d2, n)) - (b0 - sq(d0, n))

        res = np.vectorize(calc)(b0, b1, b2, d0, d1, d2, n)

    elif type == 'pp':
        # ylabel = 'particle-particle correlator'
        # In this and the following branches, b0 and d0 are treated as
        # uncorrelated (no covariance, unlike "ss").

        b0 = corels[0]['pp'].pack()
        d0 = corels[0]['dens'].pack()

        def calc(b0, d0, n):
            if n == 0: return uncertainties.ufloat(0, 0)
            b0 = uncertainties.ufloat(b0.value, b0.error)
            d0 = uncertainties.ufloat(d0.value, d0.error)
            return b0 - sq(d0, n)

        res = np.vectorize(calc)(b0, d0, n)

    elif type == 'hh':
        # ylabel = 'hole-hole correlator'

        b0 = corels[0]['hh'].pack()
        d0 = corels[0]['dens'].pack()

        def calc(b0, d0, n):
            if n == 0: return uncertainties.ufloat(0, 0)
            b0 = uncertainties.ufloat(b0.value, b0.error)
            d0 = uncertainties.ufloat(d0.value, d0.error)
            return b0 - sq((1 - d0), n)

        res = np.vectorize(calc)(b0, d0, n)

    elif type == 'ph':
        # ylabel = 'particle-hole correlator'

        b0 = corels[0]['ph'].pack()
        d0 = corels[0]['dens'].pack()

        def calc(b0, d0, n):
            if n == 0: return uncertainties.ufloat(0, 0)
            b0 = uncertainties.ufloat(b0.value, b0.error)
            d0 = uncertainties.ufloat(d0.value, d0.error)
            return b0 - 2 * d0 + 2 * sq(d0, n)

        res = np.vectorize(calc)(b0, d0, n)

    elif type == 'g2':
        # Ratio form: hole-hole pair fraction over the corrected (1 - dens)**2.
        b0 = corels[0]['hh'].pack()
        d0 = corels[0]['dens'].pack()

        def calc(b0, d0, n):
            if n == 0: return uncertainties.ufloat(0, 0)

            b0 = uncertainties.ufloat(b0.value, b0.error)
            d0 = uncertainties.ufloat(d0.value, d0.error)
            return b0 / sq((1 - d0), n)

        res = np.vectorize(calc)(b0, d0, n)

    elif type == 'g2n':
        # Connected hole-hole correlator divided by d**2, plus 1.
        # NOTE(review): d == 0 raises ZeroDivisionError inside calc.
        assert d is not None
        if d < 0:
            print "WARNING, NEGATIVE DOPING VALUE; using absolute value for d = ", d
            d = abs(d)

        b0 = corels[0]['hh'].pack()
        d0 = corels[0]['dens'].pack()

        def calc(b0, d0, n):
            if n == 0: return uncertainties.ufloat(0, 0)

            b0 = uncertainties.ufloat(b0.value, b0.error)
            d0 = uncertainties.ufloat(d0.value, d0.error)

            return (b0 - sq((1 - d0), n)) / d ** 2 + 1.0

        res = np.vectorize(calc)(b0, d0, n)

    def getn(a):
        return a.n  # nominal value of an uncertainties number

    def gets(a):
        return a.s  # standard deviation of an uncertainties number

    cor_map, err = np.vectorize(getn)(res), np.vectorize(gets)(res)
    # Flag poorly sampled displacements with -2.  The "g2" filter below drops
    # them explicitly; the |value| <= 1 filter drops them for the other types
    # except "g2n".
    cor_map[corels[0]['to'] < min_samps] = -2

    xs = []
    ys = []
    es = []
    # Collect (radius, value, error) over one half of the displacement plane:
    # rows from the centre row (2*hsize-1) onward, and on the centre row only
    # columns left of the centre.  This skips the origin and the
    # mirror-image displacements.
    # NOTE(review): the range stops at 2*(2*hsize-1)-1, so the last map row
    # (largest displacement) is never used.
    for i in range((2 * hsize - 1), 2 * (2 * hsize - 1)):
        for j in range(cor_map.shape[1]):
            if type == 'g2n':
                # NOTE(review): no value filter here, so entries masked to -2
                # are averaged in as well -- confirm intended.
                if not (i == (2 * hsize - 1) and j >= (2 * hsize - 1)):
                    xs.append(((i - (2 * hsize - 1)) ** 2 + (j - (2 * hsize - 1)) ** 2) ** 0.5)
                    ys.append(cor_map[i][j])
                    es.append(err[i][j])
            elif type != 'g2':
                if np.abs(cor_map[i][j]) <= 1 and not (i == (2 * hsize - 1) and j >= (2 * hsize - 1)):
                    xs.append(((i - (2 * hsize - 1)) ** 2 + (j - (2 * hsize - 1)) ** 2) ** 0.5)
                    # (-1)**(i+j) has the parity of the displacement, since the
                    # centre index is counted twice.
                    if type == "ss":  # Only do sign correction for spin-spin correlator
                        ys.append(cor_map[i][j] * (-1) ** (i + j))
                    else:
                        ys.append(cor_map[i][j])
                    es.append(err[i][j])
            else:
                # Only "g2" reaches this branch, so the `type == "ss"` test
                # below is always False.
                if cor_map[i][j] > -2 and not (i == (2 * hsize - 1) and j >= (2 * hsize - 1)):
                    xs.append(((i - (2 * hsize - 1)) ** 2 + (j - (2 * hsize - 1)) ** 2) ** 0.5)
                    if type == "ss":  # Only do sign correction for spin-spin correlator
                        ys.append(cor_map[i][j] * (-1) ** (i + j))
                    else:
                        ys.append(cor_map[i][j])
                    es.append(err[i][j])

    # Sort by radius.  NOTE(review): raises ValueError when no displacement
    # passed the filters above (nothing to unpack).
    xsn, ysn, esn = zip(*sorted(zip(xs, ys, es), key=lambda (x, y, e): x))

    xs = []
    ys = []
    es = []
    tot = []
    errs = []
    cur = None

    # Group consecutive radii lying within 0.01 of the group's first radius
    # and store their inverse-variance weighted mean.  Zero errors are
    # replaced by 99999, which gives those entries negligible weight.
    for i in range(len(xsn)):
        if cur is None:
            cur = xsn[i]
        if i < len(xsn) - 1 and abs(xsn[i + 1] - cur) < 0.01:  # 0.333:
            tot += [ysn[i]]
            errs += [esn[i]] if esn[i] > 0 else [99999]
        else:
            tot += [ysn[i]]
            errs += [esn[i]] if esn[i] > 0 else [99999]
            if len(tot) >= min_samps_azi or cur <= min_radius_azi:
                xs += [cur]
                ys += [np.sum(np.array(tot) * np.array(errs) ** -2) / np.sum(np.array(errs) ** -2)]
                es += [1 / np.sqrt(np.sum(np.array(errs) ** -2))]

            tot = []
            errs = []
            cur = None

    return xs, ys, es
